{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "77be7dda",
   "metadata": {},
   "source": [
    "## Mixtral-8x7B deployment guide with lmi-dist rolling batch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d69a078d",
   "metadata": {},
   "source": [
    "### In this tutorial, you will use the lmi-dist backend of the Large Model Inference (LMI) DLC to deploy Mixtral-8x7B and run inference against it.\n",
    "\n",
    "Please make sure the following permissions are granted before running the notebook:\n",
    "\n",
    "* S3 bucket push access\n",
    "* SageMaker access\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bc171a0",
   "metadata": {},
   "source": [
    "### Step 1: Upgrade the SageMaker SDK and import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdd49db4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc3c2465",
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker import Model, image_uris, serializers, deserializers\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs\n",
    "region = sess.boto_region_name  # region name of the current SageMaker Studio environment\n",
    "account_id = sess.account_id()  # AWS account ID of the current session"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ab26348",
   "metadata": {},
   "source": [
    "### Step 2: Start preparing model artifacts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86df559e",
   "metadata": {},
   "source": [
    "The LMI container expects the following artifacts to set up the model:\n",
    "\n",
    "* serving.properties (required): Defines the model server settings\n",
    "* model.py (optional): A Python file that defines the core inference logic\n",
    "* requirements.txt (optional): Lists any additional pip packages to install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19ff0dc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile serving.properties\n",
    "engine=MPI\n",
    "option.model_id=mistralai/Mixtral-8x7B-v0.1\n",
    "option.tensor_parallel_degree=8\n",
    "option.max_rolling_batch_size=32\n",
    "option.rolling_batch=lmi-dist"
   ]
  },
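  {
   "cell_type": "markdown",
   "id": "3f9a1c7e",
   "metadata": {},
   "source": [
    "As a quick sanity check before packaging, the settings above can be loaded into a dictionary and inspected. The cell below is a minimal sketch: `parse_properties` is a hypothetical helper written for this notebook (it is not part of the LMI container or the SageMaker SDK), and it only handles plain `key=value` lines. Note that `option.tensor_parallel_degree=8` lines up with the 8 GPUs of the `ml.g5.48xlarge` instance used in Step 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d45e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper for illustration; not part of the LMI container or the SageMaker SDK.\n",
    "def parse_properties(text):\n",
    "    # Parse simple key=value lines, skipping blank lines and # comments.\n",
    "    settings = {}\n",
    "    for line in text.splitlines():\n",
    "        line = line.strip()\n",
    "        if not line or line.startswith('#'):\n",
    "            continue\n",
    "        key, _, value = line.partition('=')\n",
    "        settings[key.strip()] = value.strip()\n",
    "    return settings\n",
    "\n",
    "\n",
    "# Same content as the serving.properties file written above.\n",
    "properties_text = '''\n",
    "engine=MPI\n",
    "option.model_id=mistralai/Mixtral-8x7B-v0.1\n",
    "option.tensor_parallel_degree=8\n",
    "option.max_rolling_batch_size=32\n",
    "option.rolling_batch=lmi-dist\n",
    "'''\n",
    "\n",
    "settings = parse_properties(properties_text)\n",
    "for key, value in settings.items():\n",
    "    print(f'{key} = {value}')"
   ]
  },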
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c21521",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sh\n",
    "mkdir mymodel\n",
    "mv serving.properties mymodel/\n",
    "tar czvf mymodel.tar.gz mymodel/\n",
    "rm -rf mymodel"
   ]
  },
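  {
   "cell_type": "markdown",
   "id": "6c1e9f24",
   "metadata": {},
   "source": [
    "The shell commands above can also be expressed in Python. The sketch below rebuilds the same archive layout with the standard-library `tarfile` module and lists the members to confirm that `serving.properties` sits under the top-level `mymodel/` folder. It works inside a temporary directory, an assumption made so that it does not touch the notebook's own files, and the one-line file content is only a placeholder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e07b3a5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tarfile\n",
    "import tempfile\n",
    "\n",
    "# Work in a temporary directory so the notebook's own files are left untouched.\n",
    "with tempfile.TemporaryDirectory() as workdir:\n",
    "    model_dir = os.path.join(workdir, 'mymodel')\n",
    "    os.makedirs(model_dir)\n",
    "    with open(os.path.join(model_dir, 'serving.properties'), 'w') as f:\n",
    "        f.write('engine=MPI')  # placeholder content for the sketch\n",
    "\n",
    "    archive_path = os.path.join(workdir, 'mymodel.tar.gz')\n",
    "    # arcname keeps the top-level mymodel/ folder inside the archive, like the tar command above.\n",
    "    with tarfile.open(archive_path, 'w:gz') as tar:\n",
    "        tar.add(model_dir, arcname='mymodel')\n",
    "\n",
    "    # List the archive members to confirm the layout.\n",
    "    with tarfile.open(archive_path, 'r:gz') as tar:\n",
    "        members = sorted(tar.getnames())\n",
    "\n",
    "print(members)"
   ]
  },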
  {
   "cell_type": "markdown",
   "id": "5a817d0e",
   "metadata": {},
   "source": [
    "### Step 3: Start building SageMaker endpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42b20125",
   "metadata": {},
   "source": [
    "#### Getting the container image URI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cef43453",
   "metadata": {},
   "source": [
    "See the available Large Model Inference DLCs [here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e61a908b",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_uri = image_uris.retrieve(\n",
    "    framework=\"djl-deepspeed\",\n",
    "    region=sess.boto_session.region_name,\n",
    "    version=\"0.27.0\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96a08056",
   "metadata": {},
   "source": [
    "#### Upload the artifact to S3 and create the SageMaker model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f903af58",
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_code_prefix = \"large-model-lmi-dist/code\"\n",
    "bucket = sess.default_bucket()  # bucket to house artifacts\n",
    "code_artifact = sess.upload_data(\"mymodel.tar.gz\", bucket, s3_code_prefix)\n",
    "print(f\"S3 code/model tarball uploaded to: {code_artifact}\")\n",
    "\n",
    "model = Model(image_uri=image_uri, model_data=code_artifact, role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c04c778f",
   "metadata": {},
   "source": [
    "#### Create SageMaker endpoint with a specified instance type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f27a8df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "instance_type = \"ml.g5.48xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(\"lmi-model\")\n",
    "print(f\"endpoint_name: {endpoint_name}\")\n",
    "\n",
    "model.deploy(initial_instance_count=1,\n",
    "             instance_type=instance_type,\n",
    "             endpoint_name=endpoint_name,\n",
    "             container_startup_health_check_timeout=1800\n",
    "            )\n",
    "\n",
    "# our requests and responses are in JSON format, so we specify both the serializer and the deserializer\n",
    "predictor = sagemaker.Predictor(\n",
    "    endpoint_name=endpoint_name,\n",
    "    sagemaker_session=sess,\n",
    "    serializer=serializers.JSONSerializer(),\n",
    "    deserializer=deserializers.JSONDeserializer(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19004557",
   "metadata": {},
   "source": [
    "### Step 4: Run inference"
   ]
  },
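  {
   "cell_type": "markdown",
   "id": "91c4d6fa",
   "metadata": {},
   "source": [
    "The request body is a JSON object with an `inputs` string and a `parameters` object. The cell below sketches how such a payload can be assembled and inspected locally before calling the endpoint. `build_payload` is a hypothetical helper written for this notebook, and it only covers the two generation parameters used here (`max_new_tokens` and `do_sample`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ad8207b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "# Hypothetical helper for illustration; the endpoint only needs the resulting JSON body.\n",
    "def build_payload(prompt, max_new_tokens=128, do_sample=True):\n",
    "    if not prompt:\n",
    "        raise ValueError('prompt must be a non-empty string')\n",
    "    if max_new_tokens <= 0:\n",
    "        raise ValueError('max_new_tokens must be positive')\n",
    "    return {\n",
    "        'inputs': prompt,\n",
    "        'parameters': {'max_new_tokens': max_new_tokens, 'do_sample': do_sample},\n",
    "    }\n",
    "\n",
    "\n",
    "payload = build_payload('The future of Artificial Intelligence is')\n",
    "# Pretty-print the payload; a JSON serializer sends an equivalent document to the endpoint.\n",
    "print(json.dumps(payload, indent=2))"
   ]
  },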
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccbf0046",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\"inputs\": \"The future of Artificial Intelligence is\", \"parameters\": {\"max_new_tokens\":128, \"do_sample\":True}}\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p310",
   "language": "python",
   "name": "conda_pytorch_p310"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
